from math import sqrt
from constant import *
from Agents import *
from Ball import ball

class WorldModel(Ui_MainWindow,QMainWindow):
    """Qt main window hosting a two-player ball game and its agent-facing API.

    The world owns a ball and two ``human_agent`` players (``agent1`` and
    ``agent2``). It tracks the game state and both scores, and it translates
    keyboard input into player actions. It also exposes an environment-style
    interface (``get_current_state``, ``submit_act``, ``get_action_space``,
    ``get_state_space``) that returns observations and rewards for player 1.
    """

    def __init__(self):
        super().__init__()
        self.setupUi(self)
        self.initUI()
        # Key-state fields and a placeholder target; they are reset every
        # round but not read anywhere within this class.
        self.a = 0
        self.w = 0
        self.s = 0
        self.d = 0
        self.j = 0
        self.tx = 1000
        self.ty = 1000
        # Ball and both players. All three must exist before
        # get_current_state() is called below.
        self.agent_group = []
        self.add_ball()
        self.agent1 = self.add_agent(1)
        self.agent2 = self.add_agent(2)
        self.timestep = sleep_time
        # Drives update_world(). It is not started here; presumably the caller
        # starts it when running the real-time loop.
        self.timer = QTimer()
        self.game_state = P1_START
        self.curstate = self.get_current_state()
        self.timer.timeout.connect(self.update_world)
        self.freezetime = FREEAZ
        self.unfreezed = True
        self.p2sc = 0
        self.p1sc = 0
        # Set by cal_reward() when a rally ends; consumed by update_world().
        self.isover = False

    def initUI(self):
        """Show the main window."""
        self.show()

    def add_agent(self, type=1):
        """Create a ``human_agent`` for player *type*, register it, and return it.

        The parameter name shadows the builtin ``type`` but is kept so that
        keyword callers (``add_agent(type=2)``) keep working.
        """
        agent = human_agent(self, type=type)
        self.agent_group.append(agent)
        return agent

    def add_ball(self):
        """Create the ball, store it as ``self.ball`` and register it."""
        self.ball = ball(self)
        self.agent_group.append(self.ball)

    # ------------------------------------------------------------------
    # Keyboard control
    # ------------------------------------------------------------------

    def _key_bindings(self):
        """Return ``(qt_key, agent, setter_name)`` triples for keyboard control.

        Player 2 (``agent2``) uses W/S/A/D and J. Player 1 (``agent1``) uses
        the arrow keys and Enter. Qt reports the main-keyboard Enter key as
        ``Key_Return`` and the keypad one as ``Key_Enter``, so both are bound.
        """
        return (
            (Qt.Key_W, self.agent2, 'set_w'),
            (Qt.Key_S, self.agent2, 'set_s'),
            (Qt.Key_A, self.agent2, 'set_a'),
            (Qt.Key_D, self.agent2, 'set_d'),
            (Qt.Key_J, self.agent2, 'set_j'),
            (Qt.Key_Up, self.agent1, 'set_w'),
            (Qt.Key_Down, self.agent1, 'set_s'),
            (Qt.Key_Left, self.agent1, 'set_a'),
            (Qt.Key_Right, self.agent1, 'set_d'),
            (Qt.Key_Enter, self.agent1, 'set_j'),
            (Qt.Key_Return, self.agent1, 'set_j'),
        )

    def _apply_key(self, key, value):
        """Forward *value* (1 = pressed, 0 = released) to the setter bound to *key*."""
        # Linear scan with ``==`` (rather than a dict lookup) keeps the same
        # comparison semantics as testing ``event.key() == Qt.Key_*`` directly.
        for bound_key, agent, setter in self._key_bindings():
            if key == bound_key:
                getattr(agent, setter)(value)

    def keyPressEvent(self, event):
        """Qt handler: mark the action bound to the pressed key as active."""
        self._apply_key(event.key(), 1)

    def keyReleaseEvent(self, event):
        """Qt handler: clear the action bound to the released key."""
        self._apply_key(event.key(), 0)

    # ------------------------------------------------------------------
    # Environment interface
    # ------------------------------------------------------------------

    def get_action_space(self):
        """Return the size of the action space (fixed at 18)."""
        return 18

    def get_state_space(self):
        """Return the length of the observation from get_current_state() (9)."""
        return 9

    def get_current_state(self):
        """Return the current observation as a 9-element NumPy array.

        Layout: ``[game_state, agent1.x, agent1.y, agent2.x, agent2.y,
        ball.x, ball.y, ball.tx, ball.ty]``.
        """
        return np.array([
            self.game_state,
            self.agent1.x(), self.agent1.y(),
            self.agent2.x(), self.agent2.y(),
            self.ball.x(), self.ball.y(),
            self.ball.tx, self.ball.ty,
        ])

    def _reset_round(self):
        """Refresh the score labels, reposition everything and clear per-round state.

        The ball is placed at the position of the player whose ``*_START``
        state is current.
        """
        self.r_sc_label.setText(str(self.p1sc))
        self.l_score_label.setText(str(self.p2sc))

        self.agent1.reset()
        self.agent2.reset()
        if self.game_state == P1_START:
            self.ball.breset(self.agent1.x(), self.agent1.y())
        elif self.game_state == P2_START:
            self.ball.breset(self.agent2.x(), self.agent2.y())

        self.a = self.w = self.s = self.d = self.j = 0
        self.tx = right_x
        self.ty = bottom
        self.curstate = self.get_current_state()
        self.isover = False

    def fresh_epoch(self):
        """Start a new match: zero both scores and begin with player 1 starting."""
        self.game_state = P1_START
        self.p1sc = 0
        self.p2sc = 0
        self._reset_round()

    def reset(self):
        """End the current rally, award the point, and start the next round.

        If play was in progress (not a ``*_START`` state):

        - ``ball.y() < mid_up`` gives player 1 a point and sets ``P1_START``.
        - ``ball.y() > mid_down`` gives player 2 a point and sets ``P2_START``.

        Scores wrap modulo 200.
        """
        if self.game_state not in (P1_START, P2_START):
            if self.ball.y() < mid_up:
                self.game_state = P1_START
                self.p1sc = (self.p1sc + 1) % 200
            if self.ball.y() > mid_down:
                self.game_state = P2_START
                self.p2sc = (self.p2sc + 1) % 200
            # NOTE(review): if mid_up <= ball.y() <= mid_down, game_state stays
            # a play state and the ball is not repositioned below -- confirm
            # this cannot happen or is intended.
        self._reset_round()

    def get_game_state(self):
        """Return the current game state constant."""
        return self.game_state

    def get_ball_target(self):
        """Return the ball's current target ``(tx, ty)``."""
        return self.ball.tx, self.ball.ty

    def get_ball_pos(self):
        """Return the ball's current position ``(x, y)``."""
        return self.ball.x(), self.ball.y()

    def is_ball_hitable(self, x, y):
        """Return True if the ball is within ``HIT_RANGE`` of point ``(x, y)``."""
        return get_dis(x, y, self.ball.x(), self.ball.y()) <= HIT_RANGE

    def change_game_state(self, player, x, y):
        """Register a hit by *player* and send the ball towards ``(x, y)``.

        During play, the hit is ignored (early return) in two cases:

        - player 1 hits while ``ball.ty < mid_up``;
        - player 2 hits while ``ball.ty > mid_down``.

        A start state switches to the opponent's ``PLAY_*_ON`` state;
        otherwise the play state is toggled.
        """
        if self.game_state not in (P1_START, P2_START):
            if self.ball.ty < mid_up and player == 1:
                return
            if self.ball.ty > mid_down and player == 2:
                return
        if self.game_state == P1_START:
            self.game_state = PLAY_2_ON
        elif self.game_state == P2_START:
            self.game_state = PLAY_1_ON
        else:
            # NOTE(review): toggling with (+1) % 2 assumes PLAY_1_ON and
            # PLAY_2_ON are the values 0 and 1 -- confirm in `constant`.
            self.game_state = (self.game_state + 1) % 2
        self.ball.set_t(x, y)

    # ------------------------------------------------------------------
    # Rewards
    # ------------------------------------------------------------------

    def cal_pt_reward(self, x0, y0, x1, y1, tx, ty):
        """Reward moving from ``(x0, y0)`` to ``(x1, y1)`` closer to ``(tx, ty)``.

        The reward is ``move_r * (old_distance - new_distance) / step``. It is
        positive when the move reduces the distance to the target.
        """
        dis_before = get_dis(x0, y0, tx, ty)
        dis_after = get_dis(x1, y1, tx, ty)
        return move_r * (dis_before - dis_after) / step

    def cal_reward(self, state, player):
        """Return player 1's step reward for the transition ``self.curstate -> state``.

        Only player 1 is rewarded; any other *player* gets 0.0 with no side
        effects. While play is in progress, the ball position in *state* is
        checked against the court bounds:

        - out of bounds on player 1's side costs ``sc_r``;
        - out of bounds on the far side earns ``sc_r``.

        Either outcome sets ``self.isover``. A distance-shaping term from
        cal_pt_reward() is then added:

        - in ``PLAY_1_ON``, towards the previous ball target;
        - in ``PLAY_2_ON``, towards ``DOWNMID``.
        """
        reward = 0.0
        if player != 1:
            return reward

        if self.game_state not in (P1_START, P2_START):
            bx, by = state[5], state[6]
            out_sideways = bx <= left_x - 10 or bx >= right_x + 10
            if (out_sideways and by >= mid_down) or by >= bottom + 50:
                reward -= sc_r
                self.isover = True
            if (out_sideways and by <= mid_up) or by <= up - 50:
                reward += sc_r
                self.isover = True

        ax_prev, ay_prev = self.curstate[1], self.curstate[2]
        ax_now, ay_now = state[1], state[2]
        game_state = self.get_game_state()
        if game_state == PLAY_1_ON:
            reward += self.cal_pt_reward(ax_prev, ay_prev, ax_now, ay_now,
                                         self.curstate[7], self.curstate[8])
        if game_state == PLAY_2_ON:
            reward += self.cal_pt_reward(ax_prev, ay_prev, ax_now, ay_now,
                                         DOWNMID[0], DOWNMID[1])
        return reward

    def cal_hit_reward(self, x1, y1, x2, y2, tx, ty):
        """Return the reward for a hit, based on the triangle P1-P2-T.

        Notation:

        - ``P1 = (x1, y1)``, ``P2 = (x2, y2)``, ``T = (tx, ty)``;
        - ``a = |P1P2|``, ``b = |P1T|``, ``c = |P2T|``;
        - ``x = (a**2 + c**2 - b**2) / (2c)``, which by the law of cosines is
          the projection of P2->P1 onto the P2->T direction.

        The result is a base reward of ``0.4 * hit_r`` plus a ``0.6 * hit_r``
        bonus scaled by an angle-derived factor. If the computation raises an
        arithmetic error (e.g. zero distance) or a math-domain error (``asin``
        argument outside [-1, 1]), only the base reward is returned.
        """
        import math  # local import so `math` does not depend on star imports

        base = hit_r * 0.4
        bonus = hit_r * 0.6
        try:
            a = get_dis(x1, y1, x2, y2)
            b = get_dis(x1, y1, tx, ty)
            c = get_dis(x2, y2, tx, ty)
            # NOTE(review): if get_dis returns NumPy scalars, division by zero
            # yields inf/nan instead of raising ZeroDivisionError.
            x = (a * a + c * c - b * b) / (2 * c)
            if x <= 0:
                t = math.asin((c - x) / b) - math.asin(-x / a)
                return base + bonus * math.sin(t)
            t1 = math.asin(x / a)
            t2 = math.asin((c - x) / a)
            if math.degrees(t1) + math.degrees(t2) > 90:
                return base + bonus * ((c - x) / b) * 1.2
            return base + bonus * math.sin(t1 + t2)
        except (ArithmeticError, ValueError):
            return base

    def submit_act(self, act, x0, y0, x, y, player):
        """Apply *player*'s action *act* and return ``(state, reward)``.

        Args:
            act: The chosen action. If it is in ``attack_acts`` and the ball is
                within hit range of ``(x0, y0)``, a hit is registered via
                change_game_state() with target ``(x, y)`` and a hit reward is
                added.
            x0, y0: The point tested against the ball for a hit.
            x, y: The ball target passed to change_game_state().
            player: The acting player (1 or 2).

        Returns:
            A ``(state, reward)`` pair. ``state`` is the observation taken
            *before* the hit is applied; it also becomes ``self.curstate``.
        """
        reward = 0.0
        state = self.get_current_state()
        if act in attack_acts and self.is_ball_hitable(x0, y0):
            # NOTE(review): always uses agent1/agent2 positions and the
            # *previous* ball target, regardless of `player` and of the new
            # target (x, y) -- confirm this is intended.
            reward += self.cal_hit_reward(self.curstate[1], self.curstate[2],
                                          self.curstate[3], self.curstate[4],
                                          self.curstate[7], self.curstate[8])
            self.change_game_state(player, x, y)

        reward += self.cal_reward(state, player)
        self.curstate = state
        return self.curstate, reward

    def train(self):
        """Training loop hook; must be implemented by a subclass."""
        raise NotImplementedError

    def update_world(self):
        """Advance the world by one timer tick and repaint.

        If the rally is already over, the round is reset. Otherwise agent2
        picks and executes an action, and the ball moves unless that action
        ended the rally (in which case the round is reset instead). agent1 is
        not stepped here; presumably it is driven externally via submit_act().
        """
        if not self.isover:
            self.agent2.set_action()
            self.agent2.execute_action()
        if self.isover:
            self.reset()
        else:
            self.ball.execute_action()
        self.update()

        
